Setup¶
In [ ]:
import h5py
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
import matplotlib.animation as animation
from mpl_toolkits.mplot3d import Axes3D
from scipy.linalg import svd
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from einops import rearrange
In [ ]:
!pip install torchvision
Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (0.21.0+cu124) Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchvision) (2.0.2) Requirement already satisfied: torch==2.6.0 in /usr/local/lib/python3.11/dist-packages (from torchvision) (2.6.0+cu124) Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.11/dist-packages (from torchvision) (11.2.1) Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0->torchvision) (3.18.0) Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0->torchvision) (4.13.2) Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0->torchvision) (3.4.2) Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0->torchvision) (3.1.6) Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0->torchvision) (2025.3.2) Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0->torchvision) (12.4.127) Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0->torchvision) (12.4.127) Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0->torchvision) (12.4.127) Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0->torchvision) (9.1.0.70) Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0->torchvision) (12.4.5.8) Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0->torchvision) (11.2.1.3) Requirement 
already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0->torchvision) (10.3.5.147) Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0->torchvision) (11.6.1.9) Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0->torchvision) (12.3.1.170) Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0->torchvision) (0.6.2) Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0->torchvision) (2.21.5) Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0->torchvision) (12.4.127) Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0->torchvision) (12.4.127) Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0->torchvision) (3.2.0) Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch==2.6.0->torchvision) (1.13.1) Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch==2.6.0->torchvision) (1.3.0) Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch==2.6.0->torchvision) (3.0.2)
In [ ]:
from google.colab import auth
auth.authenticate_user()
In [ ]:
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
Dataset¶
Testing (Evaluation)¶
In [ ]:
def get_eval_data(scale=4, patch_size=64):
    """Load the held-out Re16000 snapshots and wrap them in a FluidFlowDataset.

    Args:
        scale (int): Super-resolution factor used to derive the LR inputs.
        patch_size (int): HR crop size passed to the dataset.

    Returns:
        FluidFlowDataset: evaluation dataset over the (u, v, w) channels.
    """
    filepath = "/content/drive/MyDrive/CAKRes/DatasetRe16k/test/nskt_Re16000.h5"
    n_skip = 2  # subsample the 2048^2 grid by this factor in each direction
    # Slicing an h5py dataset yields in-memory numpy arrays, so the file can
    # be closed as soon as the three channels have been read.  The previous
    # version also computed reshaped_*/mean_* arrays that were never used;
    # that dead work has been removed.
    with h5py.File(filepath, 'r') as file:
        dataset = file['fields']
        display(dataset.shape)
        u = dataset[::, 0, ::n_skip, ::n_skip]
        v = dataset[::, 1, ::n_skip, ::n_skip]
        w = dataset[::, 2, ::n_skip, ::n_skip]
    return FluidFlowDataset(u, v, w, scale=scale, patch_size=patch_size)
eval_dataset = get_eval_data()
In [ ]:
# Deterministic evaluation loader: shuffle=False keeps metric runs reproducible.
eval_loader = DataLoader(eval_dataset, batch_size=64, shuffle=False)
Training¶
Here we use only 1000 samples to test the baseline model (also due to limited compute resources).
In [ ]:
filepath = "/content/drive/MyDrive/CAKRes/DatasetRe16k/train/nskt_Re16000.h5"
# Open the training H5 file; it is kept open because `fields` is sliced in
# later cells.
file = h5py.File(filepath, 'r')

def print_structure(name, obj):
    """Visitor callback: print every dataset and group found in the file."""
    if isinstance(obj, h5py.Dataset):
        print(f"Dataset: {name}, Shape: {obj.shape}, Type: {obj.dtype}")
    elif isinstance(obj, h5py.Group):
        print(f"Group: {name}")

file.visititems(print_structure)
Dataset: fields, Shape: (1000, 3, 2048, 2048), Type: float64
In [ ]:
dataset = file['fields']  # shape (1000, 3, 2048, 2048): (time, channel, y, x)
t_skip = 5 #reduce time by this factor
n_skip = 2 #reduce u, v, and w grid by this factor
t_size = int(1000/t_skip)  # 200 retained snapshots
n_size = int(2048/n_skip)  # 1024 grid points per direction
#load the dataset into a numpy array
# Channels 0 and 1 are the u and v velocities; channel 2 appears to be the
# vorticity w (it is treated as such by vorticity_consistency_loss below).
u = dataset[::t_skip, 0, ::n_skip, ::n_skip]
v = dataset[::t_skip, 1, ::n_skip, ::n_skip]
w = dataset[::t_skip, 2, ::n_skip, ::n_skip]
In [ ]:
# Reshape the u,v,w data: (time, grid, grid) -> (grid x grid, time)
reshaped_u = u.reshape(t_size, n_size * n_size).T
reshaped_v = v.reshape(t_size, n_size * n_size).T
reshaped_w = w.reshape(t_size, n_size * n_size).T
# Mean and Fluctuations
mean_u = np.mean(reshaped_u, axis=1).reshape(-1, 1)
mean_v = np.mean(reshaped_v, axis=1).reshape(-1, 1)
mean_w = np.mean(reshaped_w, axis=1).reshape(-1, 1)
In [ ]:
display(u.shape)
(200, 1024, 1024)
In [ ]:
!pip install basicsr
Collecting basicsr
Downloading basicsr-1.4.2.tar.gz (172 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0.0/172.5 kB ? eta -:--:--
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━ 163.8/172.5 kB 7.8 MB/s eta 0:00:01
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 172.5/172.5 kB 4.5 MB/s eta 0:00:00
Preparing metadata (setup.py) ... done
Collecting addict (from basicsr)
Downloading addict-2.4.0-py3-none-any.whl.metadata (1.0 kB)
Requirement already satisfied: future in /usr/local/lib/python3.11/dist-packages (from basicsr) (1.0.0)
Collecting lmdb (from basicsr)
Downloading lmdb-1.6.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.1 kB)
Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.11/dist-packages (from basicsr) (2.0.2)
Requirement already satisfied: opencv-python in /usr/local/lib/python3.11/dist-packages (from basicsr) (4.11.0.86)
Requirement already satisfied: Pillow in /usr/local/lib/python3.11/dist-packages (from basicsr) (11.2.1)
Requirement already satisfied: pyyaml in /usr/local/lib/python3.11/dist-packages (from basicsr) (6.0.2)
Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from basicsr) (2.32.3)
Requirement already satisfied: scikit-image in /usr/local/lib/python3.11/dist-packages (from basicsr) (0.25.2)
Requirement already satisfied: scipy in /usr/local/lib/python3.11/dist-packages (from basicsr) (1.15.2)
Collecting tb-nightly (from basicsr)
Downloading tb_nightly-2.20.0a20250428-py3-none-any.whl.metadata (1.9 kB)
Requirement already satisfied: torch>=1.7 in /usr/local/lib/python3.11/dist-packages (from basicsr) (2.6.0+cu124)
Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (from basicsr) (0.21.0+cu124)
Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (from basicsr) (4.67.1)
Collecting yapf (from basicsr)
Downloading yapf-0.43.0-py3-none-any.whl.metadata (46 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 46.8/46.8 kB 3.3 MB/s eta 0:00:00
Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch>=1.7->basicsr) (3.18.0)
Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch>=1.7->basicsr) (4.13.2)
Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=1.7->basicsr) (3.4.2)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=1.7->basicsr) (3.1.6)
Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch>=1.7->basicsr) (2025.3.2)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.7->basicsr)
Using cached nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.7->basicsr)
Using cached nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.7->basicsr)
Using cached nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.7->basicsr)
Using cached nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.7->basicsr)
Using cached nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=1.7->basicsr)
Using cached nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch>=1.7->basicsr)
Using cached nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch>=1.7->basicsr)
Using cached nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch>=1.7->basicsr)
Using cached nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=1.7->basicsr) (0.6.2)
Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch>=1.7->basicsr) (2.21.5)
Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.7->basicsr) (12.4.127)
Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch>=1.7->basicsr)
Using cached nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=1.7->basicsr) (3.2.0)
Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=1.7->basicsr) (1.13.1)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=1.7->basicsr) (1.3.0)
Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->basicsr) (3.4.1)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->basicsr) (3.10)
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->basicsr) (2.4.0)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->basicsr) (2025.1.31)
Requirement already satisfied: imageio!=2.35.0,>=2.33 in /usr/local/lib/python3.11/dist-packages (from scikit-image->basicsr) (2.37.0)
Requirement already satisfied: tifffile>=2022.8.12 in /usr/local/lib/python3.11/dist-packages (from scikit-image->basicsr) (2025.3.30)
Requirement already satisfied: packaging>=21 in /usr/local/lib/python3.11/dist-packages (from scikit-image->basicsr) (24.2)
Requirement already satisfied: lazy-loader>=0.4 in /usr/local/lib/python3.11/dist-packages (from scikit-image->basicsr) (0.4)
Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.11/dist-packages (from tb-nightly->basicsr) (1.4.0)
Requirement already satisfied: grpcio>=1.48.2 in /usr/local/lib/python3.11/dist-packages (from tb-nightly->basicsr) (1.71.0)
Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.11/dist-packages (from tb-nightly->basicsr) (3.8)
Requirement already satisfied: protobuf!=4.24.0,>=3.19.6 in /usr/local/lib/python3.11/dist-packages (from tb-nightly->basicsr) (5.29.4)
Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.11/dist-packages (from tb-nightly->basicsr) (75.2.0)
Requirement already satisfied: six>1.9 in /usr/local/lib/python3.11/dist-packages (from tb-nightly->basicsr) (1.17.0)
Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.11/dist-packages (from tb-nightly->basicsr) (0.7.2)
Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from tb-nightly->basicsr) (3.1.3)
Requirement already satisfied: platformdirs>=3.5.1 in /usr/local/lib/python3.11/dist-packages (from yapf->basicsr) (4.3.7)
Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.11/dist-packages (from werkzeug>=1.0.1->tb-nightly->basicsr) (3.0.2)
Using cached nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)
Using cached nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)
Using cached nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)
Using cached nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)
Using cached nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)
Using cached nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)
Using cached nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)
Using cached nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)
Using cached nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)
Using cached nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)
Downloading addict-2.4.0-py3-none-any.whl (3.8 kB)
Downloading lmdb-1.6.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (297 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 297.8/297.8 kB 16.3 MB/s eta 0:00:00
Downloading tb_nightly-2.20.0a20250428-py3-none-any.whl (5.5 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.5/5.5 MB 56.9 MB/s eta 0:00:00
Downloading yapf-0.43.0-py3-none-any.whl (256 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 256.2/256.2 kB 14.6 MB/s eta 0:00:00
Building wheels for collected packages: basicsr
Building wheel for basicsr (setup.py) ... done
Created wheel for basicsr: filename=basicsr-1.4.2-py3-none-any.whl size=214817 sha256=b32ad54c3fa8690b5b054076c71e91fc9bf25346843963a3c1f22341bd16096d
Stored in directory: /root/.cache/pip/wheels/6d/a4/b3/9f888ba88efcae6dd4bbce69832363de9c4051142674f779fa
Successfully built basicsr
Installing collected packages: lmdb, addict, yapf, nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, tb-nightly, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, basicsr
Attempting uninstall: nvidia-nvjitlink-cu12
Found existing installation: nvidia-nvjitlink-cu12 12.5.82
Uninstalling nvidia-nvjitlink-cu12-12.5.82:
Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82
Attempting uninstall: nvidia-curand-cu12
Found existing installation: nvidia-curand-cu12 10.3.6.82
Uninstalling nvidia-curand-cu12-10.3.6.82:
Successfully uninstalled nvidia-curand-cu12-10.3.6.82
Attempting uninstall: nvidia-cufft-cu12
Found existing installation: nvidia-cufft-cu12 11.2.3.61
Uninstalling nvidia-cufft-cu12-11.2.3.61:
Successfully uninstalled nvidia-cufft-cu12-11.2.3.61
Attempting uninstall: nvidia-cuda-runtime-cu12
Found existing installation: nvidia-cuda-runtime-cu12 12.5.82
Uninstalling nvidia-cuda-runtime-cu12-12.5.82:
Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82
Attempting uninstall: nvidia-cuda-nvrtc-cu12
Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82
Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:
Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82
Attempting uninstall: nvidia-cuda-cupti-cu12
Found existing installation: nvidia-cuda-cupti-cu12 12.5.82
Uninstalling nvidia-cuda-cupti-cu12-12.5.82:
Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82
Attempting uninstall: nvidia-cublas-cu12
Found existing installation: nvidia-cublas-cu12 12.5.3.2
Uninstalling nvidia-cublas-cu12-12.5.3.2:
Successfully uninstalled nvidia-cublas-cu12-12.5.3.2
Attempting uninstall: nvidia-cusparse-cu12
Found existing installation: nvidia-cusparse-cu12 12.5.1.3
Uninstalling nvidia-cusparse-cu12-12.5.1.3:
Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3
Attempting uninstall: nvidia-cudnn-cu12
Found existing installation: nvidia-cudnn-cu12 9.3.0.75
Uninstalling nvidia-cudnn-cu12-9.3.0.75:
Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75
Attempting uninstall: nvidia-cusolver-cu12
Found existing installation: nvidia-cusolver-cu12 11.6.3.83
Uninstalling nvidia-cusolver-cu12-11.6.3.83:
Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83
Successfully installed addict-2.4.0 basicsr-1.4.2 lmdb-1.6.2 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127 tb-nightly-2.20.0a20250428 yapf-0.43.0
In [ ]:
from basicsr.archs.swinir_arch import SwinIR
from basicsr.data.transforms import paired_random_crop
from basicsr.utils.img_util import img2tensor, tensor2img
In [ ]:
class FluidFlowDataset(Dataset):
    """Paired LR/HR dataset of 2D fluid snapshots with (u, v, w) channels.

    Each item is a dict {'lq': low-res tensor, 'gt': high-res tensor}; the
    LR input is obtained by strided subsampling of the HR field, and both
    are randomly cropped to an aligned patch pair.
    """

    def __init__(self, u, v, w, scale=4, patch_size=64):
        self.u = u
        self.v = v
        self.w = w
        self.scale = scale
        self.patch_size = patch_size
        self.num_samples = u.shape[0]

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Stack the three fields into one HWC "image".
        hr = np.stack([self.u[idx], self.v[idx], self.w[idx]], axis=-1)
        # LR counterpart: keep every `scale`-th grid point.
        lr = hr[::self.scale, ::self.scale, :]
        # Aligned random patch pair, then CHW float32 tensors.
        hr, lr = paired_random_crop(hr, lr, self.patch_size, self.scale)
        return {
            'lq': img2tensor(lr, bgr2rgb=False, float32=True),
            'gt': img2tensor(hr, bgr2rgb=False, float32=True),
        }
Loss function¶
In [ ]:
import math
In [ ]:
def _calculate_derivatives(tensor, kernel, padding_val, boundary_mode, spacing):
"""Helper function to calculate derivatives using convolution."""
if boundary_mode != 'zeros':
pad = (padding_val, padding_val, padding_val, padding_val)
tensor_padded = F.pad(tensor, pad, mode=boundary_mode)
conv_padding = 0
else:
tensor_padded = tensor
conv_padding = padding_val
derivative = F.conv2d(tensor_padded, kernel, padding=conv_padding) / spacing
return derivative
In [ ]:
def divergence_loss(predicted_output, dx=2.0*math.pi/2048.0, dy=2.0*math.pi/2048.0, loss_type='l2', boundary_mode='zeros'):
    """Penalize non-zero 2D divergence du/dx + dv/dy of the predicted flow.

    Args:
        predicted_output: (N, C, H, W) with channel 0 = u and channel 1 = v.
        dx, dy: grid spacings for the central-difference stencils.
        loss_type: 'l1' (mean abs) or 'l2' (mean square) of the divergence.
        boundary_mode: padding mode forwarded to the derivative helper.

    Returns:
        Scalar tensor holding the divergence penalty.
    """
    if predicted_output.shape[1] < 2:
        raise ValueError(f"Divergence loss requires at least 2 channels (u, v), got {predicted_output.shape[1]}")
    dtype, device = predicted_output.dtype, predicted_output.device
    u_pred = predicted_output[:, 0:1, :, :]
    v_pred = predicted_output[:, 1:2, :, :]
    # 3x3 central-difference stencils (d/dx along columns, d/dy along rows).
    kernel_dx = torch.tensor([[[[0., 0., 0.], [-0.5, 0., 0.5], [0., 0., 0.]]]], dtype=dtype, device=device)
    kernel_dy = torch.tensor([[[[0., -0.5, 0.], [0., 0., 0.], [0., 0.5, 0.]]]], dtype=dtype, device=device)
    divergence = (_calculate_derivatives(u_pred, kernel_dx, 1, boundary_mode, dx)
                  + _calculate_derivatives(v_pred, kernel_dy, 1, boundary_mode, dy))
    if loss_type == 'l2':
        return torch.mean(divergence ** 2)
    if loss_type == 'l1':
        return torch.mean(torch.abs(divergence))
    raise ValueError("divergence loss_type must be 'l1' or 'l2'")
In [ ]:
def vorticity_consistency_loss(predicted_output, dx=2.0*math.pi/2048.0, dy=2.0*math.pi/2048.0, loss_type='l2', boundary_mode='zeros'):
    """Penalize disagreement between the predicted vorticity channel and the
    curl of the predicted velocity: ||w_pred - (dv/dx - du/dy)||.

    Args:
        predicted_output: (N, 3, H, W) with channels (u, v, w).
        dx, dy: grid spacings for the central-difference stencils.
        loss_type: 'l1' (mean abs) or 'l2' (mean square) of the residual.
        boundary_mode: padding mode forwarded to the derivative helper.

    Returns:
        Scalar tensor holding the consistency penalty.
    """
    if predicted_output.shape[1] != 3:
        raise ValueError(f"Vorticity consistency loss requires 3 channels (u, v, w), got {predicted_output.shape[1]}")
    dtype, device = predicted_output.dtype, predicted_output.device
    u_pred = predicted_output[:, 0:1, :, :]
    v_pred = predicted_output[:, 1:2, :, :]
    w_pred = predicted_output[:, 2:3, :, :]
    # 3x3 central-difference stencils (d/dx along columns, d/dy along rows).
    kernel_dx = torch.tensor([[[[0., 0., 0.], [-0.5, 0., 0.5], [0., 0., 0.]]]], dtype=dtype, device=device)
    kernel_dy = torch.tensor([[[[0., -0.5, 0.], [0., 0., 0.], [0., 0.5, 0.]]]], dtype=dtype, device=device)
    residual = w_pred - (_calculate_derivatives(v_pred, kernel_dx, 1, boundary_mode, dx)
                         - _calculate_derivatives(u_pred, kernel_dy, 1, boundary_mode, dy))
    if loss_type == 'l2':
        return torch.mean(residual ** 2)
    if loss_type == 'l1':
        return torch.mean(torch.abs(residual))
    raise ValueError("vorticity loss_type must be 'l1' or 'l2'")
In [ ]:
class L1DivergenceLoss(nn.Module):
    """Pixel-wise L1 loss plus a weighted physics-informed divergence penalty.

    total = L1(predicted, target) + lambda_div * DivergenceLoss(predicted)
    """

    def __init__(self, lambda_div=0.01, dx=2.0*math.pi/2048.0, dy=2.0*math.pi/2048.0, div_loss_type='l2', boundary_mode='zeros'):
        """
        Args:
            lambda_div (float): Weight of the divergence term.
            dx (float): Grid spacing in x for the divergence stencil.
            dy (float): Grid spacing in y for the divergence stencil.
            div_loss_type (str): 'l1' or 'l2' norm for the divergence term.
            boundary_mode (str): Padding mode ('zeros', 'reflect', 'replicate')
                used when computing derivatives.
        """
        super().__init__()
        self.lambda_div = lambda_div
        self.dx = dx
        self.dy = dy
        self.div_loss_type = div_loss_type
        self.boundary_mode = boundary_mode
        self.l1_loss = nn.L1Loss()  # pixel-wise data-fidelity term

    def forward(self, predicted_output, target_output):
        """Return the combined scalar loss.

        Args:
            predicted_output (torch.Tensor): Network output, (N, C, H, W);
                expects C=3 (u, v, w) or at least C=2 (u, v).
            target_output (torch.Tensor): Ground-truth HR tensor, same shape.

        Returns:
            torch.Tensor: scalar combined loss.
        """
        pixel_term = self.l1_loss(predicted_output, target_output)
        # The physics term depends only on the predicted (u, v) channels.
        physics_term = divergence_loss(
            predicted_output,
            dx=self.dx,
            dy=self.dy,
            loss_type=self.div_loss_type,
            boundary_mode=self.boundary_mode,
        )
        return pixel_term + self.lambda_div * physics_term
L1 + Div Loss + Vort Loss¶
In [ ]:
class L1DivVortLoss(nn.Module):
    """
    Calculates a combined loss:
    L1 Loss + lambda_div * Divergence Loss + lambda_vort * Vorticity Consistency Loss
    """
    def __init__(self, lambda_div=0.01, lambda_vort=0.01,
                 dx=2.0*math.pi/2048.0, dy=2.0*math.pi/2048.0,
                 div_loss_type='l2', vort_loss_type='l2',
                 boundary_mode='zeros'):
        """
        Args:
            lambda_div (float): Weighting factor for the divergence loss term.
            lambda_vort (float): Weighting factor for the vorticity consistency loss term.
            dx (float): Grid spacing in the x-direction for physics losses.
            dy (float): Grid spacing in the y-direction for physics losses.
            div_loss_type (str): 'l1' or 'l2' norm for the divergence loss.
            vort_loss_type (str): 'l1' or 'l2' norm for the vorticity consistency loss.
            boundary_mode (str): Boundary mode ('zeros', 'reflect', 'replicate')
                for derivative calculations.
        """
        super().__init__()
        self.lambda_div = lambda_div
        self.lambda_vort = lambda_vort
        self.dx = dx
        self.dy = dy
        self.div_loss_type = div_loss_type
        self.vort_loss_type = vort_loss_type
        self.boundary_mode = boundary_mode
        # Standard pixel-wise L1 loss component
        self.l1_loss = nn.L1Loss()

    def forward(self, predicted_output, target_output):
        """
        Calculate the combined loss.

        Args:
            predicted_output (torch.Tensor): Network output, shape (N, 3, H, W)
                with channels (u, v, w).
            target_output (torch.Tensor): Ground-truth HR tensor, (N, 3, H, W).

        Returns:
            torch.Tensor: A scalar tensor representing the combined loss.

        Raises:
            ValueError: if either tensor does not have 3 channels or shapes differ.
        """
        if predicted_output.shape[1] != 3:
            raise ValueError(f"Expected 3 channels (u, v, w) for predicted_output, got {predicted_output.shape[1]}")
        if target_output.shape[1] != 3:
            raise ValueError(f"Expected 3 channels (u, v, w) for target_output, got {target_output.shape[1]}")
        if predicted_output.shape != target_output.shape:
            raise ValueError(f"Predicted shape {predicted_output.shape} must match target shape {target_output.shape}")
        # 1. Standard pixel-wise L1 loss.
        pixel_loss = self.l1_loss(predicted_output, target_output)
        # 2. Divergence loss (uses predicted u, v only).
        div_loss = divergence_loss(predicted_output,
                                   dx=self.dx,
                                   dy=self.dy,
                                   loss_type=self.div_loss_type,
                                   boundary_mode=self.boundary_mode)
        # 3. Vorticity consistency loss (uses predicted u, v, w).
        vort_cons_loss = vorticity_consistency_loss(predicted_output,
                                                    dx=self.dx,
                                                    dy=self.dy,
                                                    loss_type=self.vort_loss_type,
                                                    boundary_mode=self.boundary_mode)
        # NOTE: a previous revision built a dict of `.item()` values here and
        # discarded it; each `.item()` forces a host-device sync, so that dead
        # code was removed.  The return value is unchanged: the weighted sum.
        return pixel_loss + self.lambda_div * div_loss + self.lambda_vort * vort_cons_loss
Evaluation Function¶
PSNR¶
In [ ]:
def psnr(target: torch.Tensor, prediction: torch.Tensor, data_range: float = None) -> float:
    """Peak Signal-to-Noise Ratio (dB) between two tensors.

    Args:
        target: Ground-truth tensor, shape (N, C, H, W), (C, H, W) or (H, W).
        prediction: Tensor of the same shape as `target`.
        data_range: Value range (max - min) of the data. When None it is
            estimated from `target` per call; a fixed, dataset-wide value is
            usually preferable because the estimate varies batch to batch.

    Returns:
        PSNR in dB as a float; +inf for identical inputs, NaN if the log
        computation fails.

    Raises:
        TypeError: if either input is not a torch.Tensor.
        ValueError: if the shapes differ.
    """
    if not isinstance(target, torch.Tensor) or not isinstance(prediction, torch.Tensor):
        raise TypeError(f"Inputs must be torch.Tensor, got {type(target)} and {type(prediction)}")
    if target.shape != prediction.shape:
        raise ValueError(f"Input shapes must match, got {target.shape} and {prediction.shape}")

    # Per-sample MSE: average over every axis except the leading one.
    non_batch_dims = tuple(range(1, target.dim())) or None
    mse = torch.mean((target - prediction) ** 2, dim=non_batch_dims)
    if torch.all(mse == 0):
        return float('inf')
    # Collapse a per-sample MSE vector to one scalar (mean over the batch).
    mse = float(torch.mean(mse))
    if mse == 0:
        return float('inf')

    if data_range is not None:
        span = data_range
    else:
        span = float(target.max() - target.min())
        if span == 0:
            # Constant target but non-zero error: report -10*log10(mse)
            # instead of dividing by a zero range.
            return -10.0 * math.log10(mse) if mse > 0 else float('inf')

    try:
        return 10.0 * math.log10(span ** 2 / mse)
    except ValueError:
        return float('nan')
In [ ]:
def evaluate_psnr(model, dataloader, device):
    """Average the per-batch PSNR of `model` predictions over `dataloader`.

    Args:
        model: Network mapping LR batches to HR predictions.
        dataloader: Yields dicts with 'lq' (input) and 'gt' (ground truth).
        device: Device the batches and model run on.

    Returns:
        float: mean PSNR across batches (one PSNR value per batch).
    """
    model.eval()
    psnr_sum = 0.0
    with torch.no_grad():
        for sample in dataloader:
            lq = sample['lq'].to(device)
            gt = sample['gt'].to(device)
            psnr_sum += psnr(gt, model(lq))
    return psnr_sum / len(dataloader)
SSIM¶
In [ ]:
import torch
import numpy as np
from skimage.metrics import structural_similarity
# Evaluation function using SSIM
def evaluate_ssim(model, dataloader, device):
    """Return the mean per-sample SSIM of `model` over `dataloader`.

    Each batch dict must provide 'lq' (low-resolution input) and 'gt'
    (ground-truth high-resolution). SSIM is computed per image with
    skimage's structural_similarity and averaged over all samples.
    """
    model.eval()
    total_ssim = 0.0
    total_samples = 0
    with torch.no_grad():
        for batch in dataloader:
            lr = batch['lq'].to(device)
            hr = batch['gt'].to(device)  # ground-truth HR
            outputs = model(lr)          # predicted HR
            # skimage works on CPU numpy arrays.
            hr_np = hr.cpu().numpy()
            outputs_np = outputs.cpu().numpy()
            if hr_np.ndim == 4 and hr_np.shape[1] == 3:
                # NCHW -> NHWC so the channel axis is last.
                hr_np = np.transpose(hr_np, (0, 2, 3, 1))
                outputs_np = np.transpose(outputs_np, (0, 2, 3, 1))
                channel_axis_param = -1  # channels are now the last axis
            else:
                # Unexpected layout (e.g. grayscale or other channel count);
                # fall back to single-channel SSIM.
                print("Warning: Unexpected data shape. Assuming non-standard format.")
                channel_axis_param = None
            batch_size = hr_np.shape[0]
            for i in range(batch_size):
                hr_sample = hr_np[i]
                output_sample = outputs_np[i]
                # Use the ground-truth sample's dynamic range. For data
                # normalized to a known interval, a fixed data_range
                # (e.g. 1.0) is usually more robust.
                data_range = hr_sample.max() - hr_sample.min()
                if data_range == 0:
                    # Constant HR image: SSIM is 1 only for a perfect match.
                    ssim_val = 1.0 if np.all(hr_sample == output_sample) else 0.0
                else:
                    # The deprecated `multichannel` kwarg was removed in
                    # scikit-image 1.0; `channel_axis` alone carries the
                    # same information (None == single-channel).
                    ssim_val = structural_similarity(
                        hr_sample,
                        output_sample,
                        data_range=data_range,
                        channel_axis=channel_axis_param,
                    )
                total_ssim += ssim_val
            total_samples += batch_size
    # Average SSIM over all evaluated samples (0 if the loader was empty).
    return total_ssim / total_samples if total_samples > 0 else 0
Scale = 4¶
Model Setup¶
!! Note: A patch size above 64 makes the Google Colab runtime restart
In [ ]:
# Weights for the physics-based regularization terms in L1DivVortLoss.
lambda_div = 0.00001  # divergence-penalty coefficient
lambda_vort = 0.00001  # vorticity-penalty coefficient
# HR training patch size; values above 64 crash the Colab runtime.
patch_size = 64
In [ ]:
scale = 4 # Here we try with scaling of 4, next we try with 8 as well to observe the difference in scaling method
batch_size = 16
num_epochs = 50
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Create datasets: first 80% of snapshots for training, remaining 20% for validation.
split_idx = int(0.8 * len(u))
train_dataset = FluidFlowDataset(u[:split_idx], v[:split_idx], w[:split_idx], scale=scale, patch_size=patch_size)
val_dataset = FluidFlowDataset(u[split_idx:], v[split_idx:], w[split_idx:], scale=scale, patch_size=patch_size)
# Create dataloaders (shuffle only the training split).
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
# SwinIR model initialization: 3 input channels (presumably the u, v, w
# velocity components — confirm against FluidFlowDataset), pixel-shuffle
# upsampler at the chosen scale factor.
model = SwinIR(upscale=scale,
in_chans=3,
img_size=patch_size,
window_size=8,
img_range=1.,
depths=[6, 6, 6, 6],
embed_dim=60,
num_heads=[6, 6, 6, 6],
mlp_ratio=2,
upsampler='pixelshuffle',
resi_connection='1conv').to(device)
# Loss and optimizer setup: L1 reconstruction loss plus weighted divergence
# and vorticity penalties (weights defined above).
criterion = L1DivVortLoss(lambda_div=lambda_div, lambda_vort=lambda_vort, div_loss_type='l2', vort_loss_type='l2')
optimizer = optim.Adam(model.parameters(), lr=1e-4)
Train¶
In [ ]:
# Training loop: optimize on the train split, validate every epoch, and keep
# the checkpoint with the lowest validation loss.
best_val_loss = float('inf')
train_losses = []
val_losses = []
# Checkpoint name encodes scale, regularization weights, and patch size.
model_save_name = f'best_swinir_model_{scale}_div{lambda_div}_vort{lambda_vort}_patch{patch_size}.pth'
for epoch in range(num_epochs):
model.train()
train_loss = 0.0
for batch in train_loader:
lr = batch['lq'].to(device)
hr = batch['gt'].to(device)
optimizer.zero_grad()
# forward pass
outputs = model(lr)
loss = criterion(outputs, hr)
# backward pass and optimize
loss.backward()
optimizer.step()
train_loss += loss.item()
# validation (eval mode, no gradients)
model.eval()
val_loss = 0.0
with torch.no_grad():
for batch in val_loader:
lr = batch['lq'].to(device)
hr = batch['gt'].to(device)
outputs = model(lr)
loss = criterion(outputs, hr)
val_loss += loss.item()
# NOTE(review): only the last validation batch survives here, apparently
# for ad-hoc inspection — these arrays are otherwise unused in this cell.
outputs_np = outputs.detach().cpu().numpy()
hr_np = hr.detach().cpu().numpy()
# Print statistics (losses averaged over the number of batches)
train_loss /= len(train_loader)
val_loss /= len(val_loader)
train_losses.append(train_loss)
val_losses.append(val_loss)
print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
# Save best model whenever validation loss improves
if val_loss < best_val_loss:
best_val_loss = val_loss
torch.save(model.state_dict(), model_save_name)
# Store the model to Google Drive
!cp {model_save_name} /content/drive/MyDrive/CAKRes/Kapildev-WorkingSpace/{model_save_name} # Copy model to Drive
# Load best model (weights with the lowest validation loss seen)
model.load_state_dict(torch.load(model_save_name))
Epoch [1/50], Train Loss: 2.6025, Val Loss: 2.5769 Epoch [2/50], Train Loss: 2.3228, Val Loss: 2.4208 Epoch [3/50], Train Loss: 2.1617, Val Loss: 1.8134 Epoch [4/50], Train Loss: 1.5935, Val Loss: 1.3608 Epoch [5/50], Train Loss: 1.3329, Val Loss: 1.2230 Epoch [6/50], Train Loss: 1.2122, Val Loss: 1.0991 Epoch [7/50], Train Loss: 1.1144, Val Loss: 1.0778 Epoch [8/50], Train Loss: 1.0168, Val Loss: 1.0314 Epoch [9/50], Train Loss: 0.9603, Val Loss: 0.9553 Epoch [10/50], Train Loss: 0.9343, Val Loss: 0.8848 Epoch [11/50], Train Loss: 0.8708, Val Loss: 0.8446 Epoch [12/50], Train Loss: 0.8346, Val Loss: 0.8267 Epoch [13/50], Train Loss: 0.8376, Val Loss: 0.8518 Epoch [14/50], Train Loss: 0.8148, Val Loss: 0.7875 Epoch [15/50], Train Loss: 0.7785, Val Loss: 0.7392 Epoch [16/50], Train Loss: 0.7485, Val Loss: 0.7062 Epoch [17/50], Train Loss: 0.7240, Val Loss: 0.7212 Epoch [18/50], Train Loss: 0.7177, Val Loss: 0.7543 Epoch [19/50], Train Loss: 0.7151, Val Loss: 0.7116 Epoch [20/50], Train Loss: 0.7054, Val Loss: 0.6648 Epoch [21/50], Train Loss: 0.6782, Val Loss: 0.6480 Epoch [22/50], Train Loss: 0.6186, Val Loss: 0.5985 Epoch [23/50], Train Loss: 0.6461, Val Loss: 0.5862 Epoch [24/50], Train Loss: 0.6526, Val Loss: 0.6391 Epoch [25/50], Train Loss: 0.6382, Val Loss: 0.6235 Epoch [26/50], Train Loss: 0.6265, Val Loss: 0.6035 Epoch [27/50], Train Loss: 0.6155, Val Loss: 0.6190 Epoch [28/50], Train Loss: 0.6057, Val Loss: 0.5864 Epoch [29/50], Train Loss: 0.6230, Val Loss: 0.5642 Epoch [30/50], Train Loss: 0.6189, Val Loss: 0.5994 Epoch [31/50], Train Loss: 0.6058, Val Loss: 0.6068 Epoch [32/50], Train Loss: 0.5862, Val Loss: 0.5855 Epoch [33/50], Train Loss: 0.6167, Val Loss: 0.5803 Epoch [34/50], Train Loss: 0.6030, Val Loss: 0.5708 Epoch [35/50], Train Loss: 0.6136, Val Loss: 0.5823 Epoch [36/50], Train Loss: 0.6150, Val Loss: 0.6444 Epoch [37/50], Train Loss: 0.5550, Val Loss: 0.5782 Epoch [38/50], Train Loss: 0.5735, Val Loss: 0.5456 Epoch [39/50], Train Loss: 
0.5783, Val Loss: 0.6031 Epoch [40/50], Train Loss: 0.5674, Val Loss: 0.5730 Epoch [41/50], Train Loss: 0.5493, Val Loss: 0.5551 Epoch [42/50], Train Loss: 0.5579, Val Loss: 0.5541 Epoch [43/50], Train Loss: 0.5852, Val Loss: 0.5121 Epoch [44/50], Train Loss: 0.5726, Val Loss: 0.5564 Epoch [45/50], Train Loss: 0.5402, Val Loss: 0.5497 Epoch [46/50], Train Loss: 0.5754, Val Loss: 0.5403 Epoch [47/50], Train Loss: 0.5457, Val Loss: 0.5337 Epoch [48/50], Train Loss: 0.5545, Val Loss: 0.5528 Epoch [49/50], Train Loss: 0.5566, Val Loss: 0.5643 Epoch [50/50], Train Loss: 0.5623, Val Loss: 0.5587
Out[ ]:
<All keys matched successfully>
Convergence Curve¶
In [ ]:
# Convergence curves: training vs. validation loss per epoch.
plt.figure(figsize=(15, 5))
# Loss plot
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Val Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
Out[ ]:
<matplotlib.legend.Legend at 0x782922983410>
Evaluation¶
In [ ]:
# Reload the best checkpoint saved during training before evaluating.
model.load_state_dict(torch.load(f'best_swinir_model_{scale}_div{lambda_div}_vort{lambda_vort}_patch{patch_size}.pth'))
# Evaluation function
def evaluate(model, dataloader):
    """Return the mean per-batch loss of `model` over `dataloader`.

    Relies on the notebook-level `criterion` and `device` globals.
    """
    model.eval()
    loss_sum = 0.0
    with torch.no_grad():
        for sample in dataloader:
            low_res = sample['lq'].to(device)
            ground_truth = sample['gt'].to(device)
            prediction = model(low_res)
            loss_sum += criterion(prediction, ground_truth).item()
    # Average over the number of batches.
    return loss_sum / len(dataloader)
# Evaluate on the held-out test set (eval_loader is defined elsewhere in the notebook).
val_loss = evaluate(model, eval_loader)
print(f'Final Test Loss: {val_loss:.4f}')
# Experiment log — loss under different regularization weights / scales:
# With no div loss: Final Validation Loss: 1.2181
# With div loss and coeff = 1: Final Validation Loss: 410.3817
# Now with coeff = 0.1: 44.1709
# Now with coeff = 0.01: 7.2212
# Use L1 criterion: 2.3576, MSE Loss: 22.6876
# Using div loss = 0, Final Validation Loss: 1.1393
# With div coeff = 0.01 and vort coeff = 0.01: 11.4714
# With div coeff = 0.00001 and vort loss coeff = 0.00001: 1.2598
# With scale = 2, div and loss coeff = 0.00001 : 0.8166
# With scale = 2, coeffs = 0: ___
# With all dataset, scale = 4, coeff = 0.00001 : 0.5137
# With train on 1000 dataset; eval on 100 test: 0.5247
Final Test Loss: 0.5471
Use PSNR¶
In [ ]:
# Reload the best checkpoint, then report average PSNR on the test loader.
model.load_state_dict(torch.load(f'best_swinir_model_{scale}_div{lambda_div}_vort{lambda_vort}_patch{patch_size}.pth'))
val_psnr = evaluate_psnr(model, eval_loader, device)
print(f'Average Test PSNR: {val_psnr:.4f}')
# Experiment log — PSNR under different regularization weights / scales:
# With div loss and coeff = 0.01, 21.7755
# With no div loss: 27.3784 (so better with no div loss)
# With div loss coeff = 0.01 and vort loss coeff = 0.01: 22.2877
# With div coeff = 0.00001 and vort loss coeff = 0.00001: 26.4322
# With scale = 2, div and loss coeff = 0.00001 : 30.1367
# With scale = 2, coeffs = 0: ___
# With all dataset, scale = 4, coeff = 0.00001 : 35.8653
# With train on 1000 dataset; eval on 100 test: 36.2829
Average Test PSNR: 36.5077
Use SSIM¶
In [ ]:
# Reload the best checkpoint, then report average SSIM on the test loader.
model.load_state_dict(torch.load(f'best_swinir_model_{scale}_div{lambda_div}_vort{lambda_vort}_patch{patch_size}.pth'))
# Ensure model is on the correct device before evaluation
model.to(device)
val_ssim = evaluate_ssim(model, eval_loader, device)
print(f'Average Test SSIM: {val_ssim:.4f}')
# Experiment log — SSIM under different regularization weights / scales:
# Previous output: Average Validation SSIM: 0.7489 (no divergence loss used)
# Now: Average Validation SSIM: 0.1618 (divergence loss used, with divergence coeff=1)
# Now with div loss and coeff=0.1: 0.1747
# Now with coeff=0.01: 0.2772
# Using div loss coeff = 0: 0.6623
# With div loss coeff = 0.01 and vort loss coeff = 0.01: 0.2131
# With div coeff = 0.00001 and vort loss coeff = 0.00001: 0.5488
# With scale = 2, div and loss coeff = 0.00001 : 0.8226
# With scale = 2, coeffs = 0: 0.7665
# With all dataset, scale = 4, coeff = 0.00001 : 0.9165
# With train on 1000 dataset; eval on 100 test: 0.9189
Average Test SSIM: 0.9214
Visualize¶
In [ ]:
def visualize_results(model, dataloader, num_samples=3):
    """Plot LR / SR / HR channel maps for the first `num_samples` batches.

    One figure per batch: rows are the three channels of the field, columns
    are low-res input, model output, and ground truth. Uses the
    notebook-level `device` and `tensor2img` helpers.
    """
    model.eval()
    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= num_samples:
                break
            low_res = batch['lq'].to(device)
            high_res = batch['gt'].to(device)
            # Model prediction (super-resolved output).
            super_res = model(low_res)
            # Convert tensors to displayable numpy images.
            lr_img = tensor2img(low_res)
            hr_img = tensor2img(high_res)
            sr_img = tensor2img(super_res)
            panels = [
                (lr_img, 'Low Resolution'),
                (sr_img, 'Super Resolved'),
                (hr_img, 'High Resolution'),
            ]
            fig, axes = plt.subplots(3, 3, figsize=(15, 15))
            for row in range(3):  # one row per channel
                for col, (img, title) in enumerate(panels):
                    axes[row][col].imshow(img[..., row], cmap='coolwarm')
                    axes[row][col].set_title(title)
            plt.show()
On validation dataset¶
In [ ]:
# Ran with: all dataset; scale = 4; div,vort = 0.00001
model.load_state_dict(torch.load(f'best_swinir_model_4_div1e-05_vort1e-05_patch64.pth'))
# Visualize some results on the validation loader
visualize_results(model, val_loader)
# Save the model under a descriptive name for later use
torch.save(model.state_dict(), f'swinir_fluid_flow_4_div0.00001_vort0.00001.pth')
On evaluation dataset¶
In [ ]:
# Ran with: all dataset; scale = 4; div,vort = 0.00001
model.load_state_dict(torch.load(f'best_swinir_model_4_div1e-05_vort1e-05_patch64.pth'))
# Visualize some results on the held-out evaluation loader
visualize_results(model, eval_loader)
Previous¶
In [ ]:
# Ran with scale = 2; div,vort = 0.00001
model.load_state_dict(torch.load(f'best_swinir_model_2_div0.00001_vort0.00001.pth'))
# Visualize some results for the earlier scale-2 run
visualize_results(model, val_loader)
In [ ]:
# Ran with scale = 2; div,vort coeff = 0
model.load_state_dict(torch.load(f'best_swinir_model_2_div0_vort0.pth'))
# Visualize some results for the unregularized scale-2 run
visualize_results(model, val_loader)
# Save the model for later use (name depends on the current `scale` global)
torch.save(model.state_dict(), f'swinir_fluid_flow_{scale}.pth')